{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "# Python Version of Tree SHAP\n", "\n", "This is a sample implementation of Tree SHAP written in Python for easy reading." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:45.176673980Z", "start_time": "2023-10-14T15:13:42.734552921Z" }, "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\programming\\shap\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import time\n", "\n", "import numba\n", "import numpy as np\n", "import pandas as pd\n", "import sklearn.ensemble\n", "import xgboost\n", "\n", "import shap" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Load California dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:45.191494227Z", "start_time": "2023-10-14T15:13:45.179770187Z" }, "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "(1000, 8)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = shap.datasets.california(n_points=1000)\n", "X.shape" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Train sklearn random forest" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:46.873411331Z", "start_time": "2023-10-14T15:13:45.219992128Z" }, "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
RandomForestRegressor(max_depth=4, n_estimators=1000)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestRegressor(max_depth=4, n_estimators=1000)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = sklearn.ensemble.RandomForestRegressor(n_estimators=1000, max_depth=4)\n", "model.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Train XGBoost model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:47.891842198Z", "start_time": "2023-10-14T15:13:46.874216599Z" }, "collapsed": false }, "outputs": [], "source": [ "bst = xgboost.train({\"learning_rate\": 0.01, \"max_depth\": 4}, xgboost.DMatrix(X, label=y), 1000)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Python TreeExplainer\n", "\n", "This uses numba to speed things up." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:47.899642452Z", "start_time": "2023-10-14T15:13:47.898748678Z" }, "collapsed": false }, "outputs": [], "source": [ "class TreeExplainer:\n", " def __init__(self, model, **kwargs):\n", " if str(type(model)).endswith(\"sklearn.ensemble._forest.RandomForestRegressor'>\"):\n", " # self.trees = [Tree(e.tree_) for e in model.estimators_]\n", " self.trees = [\n", " Tree(\n", " children_left=e.tree_.children_left,\n", " children_right=e.tree_.children_right,\n", " children_default=e.tree_.children_right,\n", " feature=e.tree_.feature,\n", " threshold=e.tree_.threshold,\n", " value=e.tree_.value[:, 0, 0],\n", " node_sample_weight=e.tree_.weighted_n_node_samples,\n", " )\n", " for e in model.estimators_\n", " ]\n", "\n", " # Preallocate space for the unique path data\n", " maxd = np.max([t.max_depth for t in self.trees]) + 2\n", " s = (maxd * (maxd + 1)) // 2\n", " self.feature_indexes = np.zeros(s, dtype=np.int32)\n", " self.zero_fractions = np.zeros(s, dtype=np.float64)\n", " self.one_fractions = np.zeros(s, dtype=np.float64)\n", " self.pweights = np.zeros(s, dtype=np.float64)\n", "\n", " def shap_values(self, X, **kwargs):\n", " # convert dataframes\n", " if isinstance(X, pd.Series):\n", " X = X.values\n", " elif isinstance(X, pd.DataFrame):\n", " X = X.values\n", "\n", " X = np.array(X, dtype=np.float64, copy=True, order=\"C\")\n", "\n", " assert isinstance(X, np.ndarray), \"Unknown instance type: \" + str(type(X))\n", " assert len(X.shape) == 1 or len(X.shape) == 2, \"Instance must have 1 or 2 dimensions!\"\n", "\n", " # single instance\n", " if len(X.shape) == 1:\n", " phi = np.zeros(X.shape[0] + 1)\n", " x_missing = np.zeros(X.shape[0], dtype=bool)\n", " for t in self.trees:\n", " self.tree_shap(t, X, x_missing, phi)\n", " phi /= len(self.trees)\n", " elif len(X.shape) == 2:\n", " phi = np.zeros((X.shape[0], X.shape[1] + 1))\n", " x_missing = np.zeros(X.shape[1], dtype=bool)\n", " for i in range(X.shape[0]):\n", " for t in self.trees:\n", " self.tree_shap(t, X[i, :], x_missing, phi[i, :])\n", " phi /= len(self.trees)\n", " return phi\n", "\n", " def tree_shap(self, tree, x, x_missing, phi, condition=0, condition_feature=0):\n", " # update the bias term, which is the last index in phi\n", " # (note the paper has this as phi_0 instead of phi_M)\n", " if condition == 0:\n", " phi[-1] += tree.values[0]\n", "\n", " # start the recursive algorithm\n", " tree_shap_recursive(\n", " tree.children_left,\n", " tree.children_right,\n", " tree.children_default,\n", " tree.features,\n", " tree.thresholds,\n", " tree.values,\n", " tree.node_sample_weight,\n", " x,\n", " x_missing,\n", " phi,\n", " 0,\n", " 0,\n", " self.feature_indexes,\n", " self.zero_fractions,\n", " self.one_fractions,\n", " self.pweights,\n", " 1.0,\n", " 1.0,\n", " -1,\n", " condition,\n", " condition_feature,\n", " 1.0,\n", " )" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:51.276386697Z", "start_time": "2023-10-14T15:13:47.916741034Z" }, "collapsed": false }, "outputs": [], "source": [ "# extend our decision path with a fraction of one and zero extensions\n", "@numba.jit(\n", " numba.types.void(\n", " numba.types.int32[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.int32,\n", " numba.types.float64,\n", " numba.types.float64,\n", " numba.types.int32,\n", " ),\n", " nopython=True,\n", " nogil=True,\n", ")\n", "def extend_path(\n", " feature_indexes,\n", " zero_fractions,\n", " one_fractions,\n", " pweights,\n", " unique_depth,\n", " zero_fraction,\n", " one_fraction,\n", " feature_index,\n", "):\n", " feature_indexes[unique_depth] = feature_index\n", " zero_fractions[unique_depth] = zero_fraction\n", " one_fractions[unique_depth] = one_fraction\n", " if unique_depth == 0:\n", " pweights[unique_depth] = 1\n", " else:\n", " pweights[unique_depth] = 0\n", "\n", " for i in range(unique_depth - 1, -1, -1):\n", " pweights[i + 1] += one_fraction * pweights[i] * (i + 1) / (unique_depth + 1)\n", " pweights[i] = zero_fraction * pweights[i] * (unique_depth - i) / (unique_depth + 1)\n", "\n", "\n", "# undo a previous extension of the decision path\n", "@numba.jit(\n", " numba.types.void(\n", " numba.types.int32[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.int32,\n", " numba.types.int32,\n", " ),\n", " nopython=True,\n", " nogil=True,\n", ")\n", "def unwind_path(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):\n", " one_fraction = one_fractions[path_index]\n", " zero_fraction = zero_fractions[path_index]\n", " next_one_portion = pweights[unique_depth]\n", "\n", " for i in range(unique_depth - 1, -1, -1):\n", " if one_fraction != 0:\n", " tmp = pweights[i]\n", " pweights[i] = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)\n", " next_one_portion = tmp - pweights[i] * zero_fraction * (unique_depth - i) / (unique_depth + 1)\n", " else:\n", " pweights[i] = (pweights[i] * (unique_depth + 1)) / (zero_fraction * (unique_depth - i))\n", "\n", " for i in range(path_index, unique_depth):\n", " feature_indexes[i] = feature_indexes[i + 1]\n", " zero_fractions[i] = zero_fractions[i + 1]\n", " one_fractions[i] = one_fractions[i + 1]\n", "\n", "\n", "# determine what the total permuation weight would be if\n", "# we unwound a previous extension in the decision path\n", "@numba.jit(\n", " numba.types.float64(\n", " numba.types.int32[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.int32,\n", " numba.types.int32,\n", " ),\n", " nopython=True,\n", " nogil=True,\n", ")\n", "def unwound_path_sum(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):\n", " one_fraction = one_fractions[path_index]\n", " zero_fraction = zero_fractions[path_index]\n", " next_one_portion = pweights[unique_depth]\n", " total = 0\n", "\n", " for i in range(unique_depth - 1, -1, -1):\n", " if one_fraction != 0:\n", " tmp = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)\n", " total += tmp\n", " next_one_portion = pweights[i] - tmp * zero_fraction * ((unique_depth - i) / (unique_depth + 1))\n", " else:\n", " total += (pweights[i] / zero_fraction) / ((unique_depth - i) / (unique_depth + 1))\n", "\n", " return total\n", "\n", "\n", "class Tree:\n", " def __init__(\n", " self,\n", " children_left,\n", " children_right,\n", " children_default,\n", " feature,\n", " threshold,\n", " value,\n", " node_sample_weight,\n", " ):\n", " self.children_left = children_left.astype(np.int32)\n", " self.children_right = children_right.astype(np.int32)\n", " self.children_default = children_default.astype(np.int32)\n", " self.features = feature.astype(np.int32)\n", " self.thresholds = threshold\n", " self.values = value\n", " self.node_sample_weight = node_sample_weight\n", "\n", " self.max_depth = compute_expectations(\n", " self.children_left,\n", " self.children_right,\n", " self.node_sample_weight,\n", " self.values,\n", " 0,\n", " )\n", "\n", "\n", "@numba.jit(nopython=True)\n", "def compute_expectations(children_left, children_right, node_sample_weight, values, i, depth=0):\n", " if children_right[i] == -1:\n", " values[i] = values[i]\n", " return 0\n", " else:\n", " li = children_left[i]\n", " ri = children_right[i]\n", " depth_left = compute_expectations(children_left, children_right, node_sample_weight, values, li, depth + 1)\n", " depth_right = compute_expectations(children_left, children_right, node_sample_weight, values, ri, depth + 1)\n", " left_weight = node_sample_weight[li]\n", " right_weight = node_sample_weight[ri]\n", " v = (left_weight * values[li] + right_weight * values[ri]) / (left_weight + right_weight)\n", " values[i] = v\n", " return max(depth_left, depth_right) + 1\n", "\n", "\n", "# recursive computation of SHAP values for a decision tree\n", "@numba.jit(\n", " numba.types.void(\n", " numba.types.int32[:],\n", " numba.types.int32[:],\n", " numba.types.int32[:],\n", " numba.types.int32[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.boolean[:],\n", " numba.types.float64[:],\n", " numba.types.int64,\n", " numba.types.int64,\n", " numba.types.int32[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.float64[:],\n", " numba.types.float64,\n", " numba.types.float64,\n", " numba.types.int64,\n", " numba.types.int64,\n", " numba.types.int64,\n", " numba.types.float64,\n", " ),\n", " nopython=True,\n", " nogil=True,\n", ")\n", "def tree_shap_recursive(\n", " children_left,\n", " children_right,\n", " children_default,\n", " features,\n", " thresholds,\n", " values,\n", " node_sample_weight,\n", " x,\n", " x_missing,\n", " phi,\n", " node_index,\n", " unique_depth,\n", " parent_feature_indexes,\n", " parent_zero_fractions,\n", " parent_one_fractions,\n", " parent_pweights,\n", " parent_zero_fraction,\n", " parent_one_fraction,\n", " parent_feature_index,\n", " condition,\n", " condition_feature,\n", " condition_fraction,\n", "):\n", " # stop if we have no weight coming down to us\n", " if condition_fraction == 0:\n", " return\n", "\n", " # extend the unique path\n", " feature_indexes = parent_feature_indexes[unique_depth + 1 :]\n", " feature_indexes[: unique_depth + 1] = parent_feature_indexes[: unique_depth + 1]\n", " zero_fractions = parent_zero_fractions[unique_depth + 1 :]\n", " zero_fractions[: unique_depth + 1] = parent_zero_fractions[: unique_depth + 1]\n", " one_fractions = parent_one_fractions[unique_depth + 1 :]\n", " one_fractions[: unique_depth + 1] = parent_one_fractions[: unique_depth + 1]\n", " pweights = parent_pweights[unique_depth + 1 :]\n", " pweights[: unique_depth + 1] = parent_pweights[: unique_depth + 1]\n", "\n", " if condition == 0 or condition_feature != parent_feature_index:\n", " extend_path(\n", " feature_indexes,\n", " zero_fractions,\n", " one_fractions,\n", " pweights,\n", " unique_depth,\n", " parent_zero_fraction,\n", " parent_one_fraction,\n", " parent_feature_index,\n", " )\n", "\n", " split_index = features[node_index]\n", "\n", " # leaf node\n", " if children_right[node_index] == -1:\n", " for i in range(1, unique_depth + 1):\n", " w = unwound_path_sum(\n", " feature_indexes,\n", " zero_fractions,\n", " one_fractions,\n", " pweights,\n", " unique_depth,\n", " i,\n", " )\n", " phi[feature_indexes[i]] += (\n", " w * (one_fractions[i] - zero_fractions[i]) * values[node_index] * condition_fraction\n", " )\n", "\n", " # internal node\n", " else:\n", " # find which branch is \"hot\" (meaning x would follow it)\n", " hot_index = 0\n", " cleft = children_left[node_index]\n", " cright = children_right[node_index]\n", " if x_missing[split_index] == 1:\n", " hot_index = children_default[node_index]\n", " elif x[split_index] < thresholds[node_index]:\n", " hot_index = cleft\n", " else:\n", " hot_index = cright\n", " cold_index = cright if hot_index == cleft else cleft\n", " w = node_sample_weight[node_index]\n", " hot_zero_fraction = node_sample_weight[hot_index] / w\n", " cold_zero_fraction = node_sample_weight[cold_index] / w\n", " incoming_zero_fraction = 1\n", " incoming_one_fraction = 1\n", "\n", " # see if we have already split on this feature,\n", " # if so we undo that split so we can redo it for this node\n", " path_index = 0\n", " while path_index <= unique_depth:\n", " if feature_indexes[path_index] == split_index:\n", " break\n", " path_index += 1\n", "\n", " if path_index != unique_depth + 1:\n", " incoming_zero_fraction = zero_fractions[path_index]\n", " incoming_one_fraction = one_fractions[path_index]\n", " unwind_path(\n", " feature_indexes,\n", " zero_fractions,\n", " one_fractions,\n", " pweights,\n", " unique_depth,\n", " path_index,\n", " )\n", " unique_depth -= 1\n", "\n", " # divide up the condition_fraction among the recursive calls\n", " hot_condition_fraction = condition_fraction\n", " cold_condition_fraction = condition_fraction\n", " if condition > 0 and split_index == condition_feature:\n", " cold_condition_fraction = 0\n", " unique_depth -= 1\n", " elif condition < 0 and split_index == condition_feature:\n", " hot_condition_fraction *= hot_zero_fraction\n", " cold_condition_fraction *= cold_zero_fraction\n", " unique_depth -= 1\n", "\n", " tree_shap_recursive(\n", " children_left,\n", " children_right,\n", " children_default,\n", " features,\n", " thresholds,\n", " values,\n", " node_sample_weight,\n", " x,\n", " x_missing,\n", " phi,\n", " hot_index,\n", " unique_depth + 1,\n", " feature_indexes,\n", " zero_fractions,\n", " one_fractions,\n", " pweights,\n", " hot_zero_fraction * incoming_zero_fraction,\n", " incoming_one_fraction,\n", " split_index,\n", " condition,\n", " condition_feature,\n", " hot_condition_fraction,\n", " )\n", "\n", " tree_shap_recursive(\n", " children_left,\n", " children_right,\n", " children_default,\n", " features,\n", " thresholds,\n", " values,\n", " node_sample_weight,\n", " x,\n", " x_missing,\n", " phi,\n", " cold_index,\n", " unique_depth + 1,\n", " feature_indexes,\n", " zero_fractions,\n", " one_fractions,\n", " pweights,\n", " cold_zero_fraction * incoming_zero_fraction,\n", " 0,\n", " split_index,\n", " condition,\n", " condition_feature,\n", " cold_condition_fraction,\n", " )" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Compare runtime of XGBoost Tree SHAP..." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:51.429322880Z", "start_time": "2023-10-14T15:13:51.285956536Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4652073383331299\n" ] } ], "source": [ "start = time.time()\n", "shap_values = bst.predict(xgboost.DMatrix(X), pred_contribs=True)\n", "print(time.time() - start)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Versus the Python (numba) Tree SHAP..." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:13:51.678083888Z", "start_time": "2023-10-14T15:13:51.429216249Z" }, "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([-0.52625503, 0.07485998, 0.10129917, -0.00780694, 0.18581504,\n", " 0.35883265, 0.14973339, -0.03848465, 2.04406351])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = np.ones(X.shape[1])\n", "TreeExplainer(model).shap_values(x)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-10-14T15:14:01.899083922Z", "start_time": "2023-10-14T15:13:51.678228170Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.013829231262207031\n", "18.91119885444641\n" ] } ], "source": [ "start = time.time()\n", "ex = TreeExplainer(model)\n", "print(time.time() - start)\n", "start = time.time()\n", "ex.shap_values(X.iloc[:, :])\n", "print(time.time() - start)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### ...about a hundred times slower in python, even with numba" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "shap", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 1 }